package com.carol.bigdata.task.model.feature

import com.alibaba.fastjson.JSON
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import com.carol.bigdata.utils.{FuncUtil, HBaseUtil, TimeUtil}
import org.apache.spark.ml.feature.{FeatureHasher, StandardScaler, StandardScalerModel, VectorAssembler, VectorDisassembler}

import scala.collection.mutable

object FeatureUtil {

    // Default mapping from a raw binary label value ("0"/"1") to its class index.
    val defaultLabelIndexMap: Map[String, Int] = Map("0" -> 0, "1" -> 1)

    /**
     * Builds the merged label RDD from the HBase label table.
     *
     * @param hbaseParams  HBase connection parameters
     * @param spark        active SparkSession
     * @param lTable       label table name
     * @param lcfList      column families to read from the label table
     * @param lKeyColumn   columns forming the record key (e.g. statDay, account_id, game_id)
     * @param labelColumns label columns to extract
     * @param labelList    per-label value lists, converted to value -> index maps
     * @param pattern      rowkey regex, e.g. ".*_game_id_.*" to scan a single game's rows
     * @return RDD of (key value list, map holding labelCol -> index and labelCol_str -> raw value)
     */
    def calLabelRDD(hbaseParams: Map[String, String],
                    spark: SparkSession,
                    lTable: String,
                    lcfList: List[String],
                    lKeyColumn: List[String],
                    labelColumns: List[String],
                    labelList: List[List[String]],
                    pattern: String = ".*"
                   ): RDD[(List[String], mutable.Map[String, String])] = {
        val labelIndexMapList = FuncUtil.labelList2IndexMap(labelList)
        val labelRDD: RDD[(List[String], mutable.Map[String, String])] = HBaseUtil
          .readCFAllAsMap(hbaseParams, spark, lTable, lcfList, List(pattern), filterMode = "PATTERN")
          .map(x => {
              var keyList = List[String]()
              for (k <- lKeyColumn) {
                  keyList :+= x.getOrElse(k, "NULL")
                  x -= k
              }
              // BUGFIX: resolve every label value BEFORE clearing the map. The
              // previous code called x.clear() inside this loop, which wiped the
              // source values so every label column after the first one always
              // fell back to "0".
              val resolved = labelColumns.zip(labelIndexMapList).map { case (c, labelIndexMap) =>
                  val label: String = x.getOrElse(c, "0")
                  (c, label, labelIndexMap.getOrElse(label, "0"))
              }
              // Keep only the label columns (raw value and index) in the output map.
              x.clear()
              for ((c, label, labelIndex) <- resolved) {
                  x.put(c + "_str", label)
                  x.put(c, labelIndex)
              }
              (keyList, x)
          })

        labelRDD
    }




    /**
     * Reads a game's user-profile rows from HBase and converts every column to a tag.
     *
     * @param hbaseParams HBase connection parameters
     * @param spark       active SparkSession
     * @param uTable      user-profile table name
     * @param ucfList     column families to read from the profile table
     * @param uKeyColumn  columns forming the record key
     * @param numColumn   numeric columns, bucketed via FuncUtil.int2tag
     * @param mapColumn   JSON-map columns, reduced to the max-valued key via FuncUtil.mapMax2tag
     * @param pattern     rowkey regex, e.g. ".*_game_id_.*" to scan a single game's rows
     * @param filterMode  scan filter mode: pattern and/or time window
     * @return RDD of (key value list, map of column -> tag)
     */
    def calUserProfileRDD(hbaseParams: Map[String, String],
                          spark: SparkSession,
                          uTable: String,
                          ucfList: List[String],
                          uKeyColumn: List[String],
                          numColumn: List[String],
                          mapColumn: List[String],
                          pattern: String = ".*",
                          filterMode: String = "PATTERN AND TIME"
                         ): RDD[(List[String], mutable.Map[String, String])] = {
        // Fetch the user-profile table as per-row mutable maps.
        val userProfileRDD: RDD[(List[String], mutable.Map[String, String])] = HBaseUtil
          .readCFAllAsMap(hbaseParams, spark, uTable, ucfList, List(pattern), filterMode)
          .map(x => {
              var keyList = List[String]()
              for (k <- uKeyColumn) {
                  keyList :+= x.getOrElse(k, "NULL")
                  x -= k
              }
              for (numCol <- numColumn) {
                  // Use a String default: the previous Int default (0) widened the
                  // expression to Any and forced a toString round-trip.
                  x.put(numCol, FuncUtil.int2tag(x.getOrElse(numCol, "0").toInt))
              }
              for (mapCol <- mapColumn) {
                  // NOTE(review): fastjson's JSON.parseObject(null) returns null —
                  // assumes FuncUtil.jsonObj2MapInt tolerates a null argument; confirm.
                  val mapValue: Map[String, Int] = FuncUtil.jsonObj2MapInt(JSON.parseObject(x.getOrElse(mapCol, null)))
                  x.put(mapCol, FuncUtil.mapMax2tag(mapValue))
              }
              (keyList, x)
          })

        userProfileRDD
    }

    /**
     * Converts tags to typed values, hashes the selected feature columns, and
     * returns one training/prediction DataFrame per label column.
     *
     * @param spark              active SparkSession
     * @param featRDD            keyed maps containing all feature columns
     * @param labelRDD           keyed maps containing all label columns
     * @param keyColumns         key columns carried through to the output
     * @param doubleFeatColumn   double-typed feature columns
     * @param strFeatColumns     string-typed feature columns
     * @param labelColumns       label columns; one output DataFrame per entry
     * @param featureCol         name of the hashed feature vector column
     * @param labelCol           output name each label column is renamed to
     * @param normalize          when true, standard-scale the double columns first
     * @param normalizeModelPath save/load path for the scaler model
     * @param numFeatures        dimensionality of the hashed feature space
     * @return one DataFrame per label column: featureCol + keyColumns + labelCol
     */
    def calHashFeatureLabelDF(spark: SparkSession,
                              featRDD: RDD[(List[String], mutable.Map[String, String])],
                              labelRDD: RDD[(List[String], mutable.Map[String, String])],
                              keyColumns: List[String],
                              doubleFeatColumn: List[String],
                              strFeatColumns: List[String],
                              labelColumns: List[String],
                              featureCol: String = "features",
                              labelCol: String = "label",
                              normalize: Boolean = false,
                              normalizeModelPath: String = "standard-models",
                              numFeatures: Int = 262144): List[DataFrame] = {

        // Inner-join labels with features on the key, then convert tags to typed values.
        val featureLabelRDD: RDD[List[Any]] = transformDataType(
            labelRDD.join(featRDD).map(x => (x._1, x._2._1 ++ x._2._2)),
            doubleFeatColumn, strFeatColumns, labelColumns)
        featureLabelRDD.take(10).foreach(println) // debug output: triggers a Spark job

        // Column order must match transformDataType: keys, strings, doubles, labels.
        val schema = StructType(
            (keyColumns ::: strFeatColumns).map(i => StructField(i, StringType))
              ::: doubleFeatColumn.map(i => StructField(i, DoubleType))
              ::: labelColumns.map(i => StructField(i, IntegerType))
        )
        var dataset = spark.createDataFrame(featureLabelRDD.map(x => Row.fromSeq(x)), schema)
        dataset.show(10) // debug output

        var doubleFeatColumnList: List[String] = doubleFeatColumn
        if (normalize) {
            val (disassemblerDataset, doubleFeatColumns) = normalizeDF(dataset, doubleFeatColumnList, normalizeModelPath)
            dataset = disassemblerDataset
            doubleFeatColumnList = doubleFeatColumns
        }

        // Hash each string tag / double value into a fixed numFeatures-dimensional
        // vector; raise numFeatures if the feature space is large.
        val hasher: FeatureHasher = new FeatureHasher()
          .setNumFeatures(numFeatures)
          .setInputCols(doubleFeatColumnList ::: strFeatColumns: _*)
          // BUGFIX: honour the featureCol parameter. The output column was
          // hard-coded to "features", so the select below failed whenever a
          // caller passed a non-default featureCol.
          .setOutputCol(featureCol)

        // transform is lazy, so hoisting it out of the loop is behavior-preserving.
        val hashed = hasher.transform(dataset)
        val hashFeatureDFList: List[DataFrame] = labelColumns.map { lbCol =>
            hashed.select(featureCol, keyColumns ::: List(lbCol): _*)
              .withColumnRenamed(lbCol, labelCol)
        }

        hashFeatureDFList
    }

    /**
     * Standard-scales the given double feature columns of a DataFrame.
     *
     * @param dataset            DataFrame whose double columns should be normalized
     * @param doubleFeatColumn   columns to normalize
     * @param normalizeModelPath path where the scaler model is saved / loaded
     * @param isTraining         true: fit and persist a new scaler; false: load the saved one
     * @return the transformed DataFrame and the names of the new scaled columns
     */
    def normalizeDF(dataset: DataFrame,
                    doubleFeatColumn: List[String],
                    normalizeModelPath: String = "standard-model",
                    isTraining: Boolean = true): (DataFrame, List[String]) = {

        val assembledCol = "double_features"
        val scaledCol = "double_features_normal"

        // Merge every double column into a single vector column first.
        val assembled = new VectorAssembler()
          .setInputCols(doubleFeatColumn.toArray)
          .setOutputCol(assembledCol)
          .transform(dataset)
        assembled.show(5, truncate = false)

        // In training mode fit a scaler on the data and persist it; otherwise
        // reuse the previously saved model so prediction matches training.
        val scaler: StandardScalerModel =
            if (isTraining) {
                val fitted = new StandardScaler()
                  .setInputCol(assembledCol)
                  .setOutputCol(scaledCol)
                  .setWithMean(true)
                  .setWithStd(true)
                  .fit(assembled)
                fitted.write.overwrite.save(normalizeModelPath)
                fitted
            } else {
                StandardScalerModel.load(normalizeModelPath)
            }

        val scaled = scaler.transform(assembled)
        scaled.show(5, truncate = false)

        // Split the scaled vector back into one column per original feature.
        val exploded: DataFrame = new VectorDisassembler()
          .setInputCol(scaledCol)
          .transform(scaled)
        exploded.show(5, truncate = false)

        // Disassembled columns are named "<scaledCol>_<index>".
        val scaledColumnNames: List[String] =
            doubleFeatColumn.indices.toList.map(i => scaledCol + "_" + i)

        (exploded, scaledColumnNames)
    }

    /**
     * Converts each keyed feature map into a flat, typed value list in the fixed
     * order: key values, string features, double features, label values.
     * Missing entries and the literal "NULL" (any case) are coerced to zero.
     *
     * @param rdd              keyed maps of column -> raw string value
     * @param doubleFeatColumn columns emitted as Double
     * @param strFeatColumns   columns emitted as String (missing -> "NULL")
     * @param labelColumns     columns emitted as Int
     * @return RDD of mixed-type value lists matching the schema built by the caller
     */
    def transformDataType(rdd: RDD[(List[String], mutable.Map[String, String])],
                          doubleFeatColumn: List[String],
                          strFeatColumns: List[String],
                          labelColumns: List[String] = List()): RDD[List[Any]] = {
        // Shared null handling: the string "NULL" (case-insensitive) becomes "0".
        def zeroIfNull(raw: String): String =
            if (raw.toUpperCase == "NULL") "0" else raw

        rdd.map { case (keys, featMap) =>
            val strValues: List[Any] =
                strFeatColumns.map(c => featMap.getOrElse(c, "NULL"))
            val doubleValues: List[Any] =
                doubleFeatColumn.map(c => zeroIfNull(featMap.getOrElse(c, "0")).toDouble)
            val labelValues: List[Any] =
                labelColumns.map(c => zeroIfNull(featMap.getOrElse(c, "0")).toInt)
            keys ::: strValues ::: doubleValues ::: labelValues
        }
    }

}
